{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trial 2: classification with learned graph filters\n",
    "\n",
    "We want to classify data by first extracting meaningful features from learned filters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import scipy.sparse, scipy.sparse.linalg, scipy.spatial.distance\n",
    "from sklearn import datasets, linear_model\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from lib import graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset\n",
    "\n",
    "* Two-digit version of MNIST with N samples of each class.\n",
    "* Distinguishing 4 from 9 is the hardest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def mnist(a, b, N):\n",
    "    \"\"\"Prepare data for binary classification of MNIST.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    a, b : labels of the two digits; `a` is mapped to -1 and `b` to +1.\n",
    "    N : number of samples taken from each of the two classes.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X : array of shape (M, 2*N) with M = data.shape[1], one sample per column;\n",
    "        the first N columns are digit `a`, the last N are digit `b`.\n",
    "    y : array of shape (2*N, 1) holding -1 for digit `a` and +1 for digit `b`.\n",
    "    \"\"\"\n",
    "    folder = os.path.join('..', 'data')\n",
    "    dataset = datasets.fetch_mldata('MNIST original', data_home=folder)\n",
    "\n",
    "    mask_a = dataset.target == a\n",
    "    mask_b = dataset.target == b\n",
    "    available = min(np.count_nonzero(mask_a), np.count_nonzero(mask_b))\n",
    "    # Asking for exactly as many samples as exist is valid, hence <= and not <.\n",
    "    assert N <= available, 'N={} exceeds the {} samples available per class'.format(N, available)\n",
    "    M = dataset.data.shape[1]\n",
    "\n",
    "    X = np.empty((M, 2, N))\n",
    "    X[:,0,:] = dataset.data[mask_a,:][:N,:].T\n",
    "    X[:,1,:] = dataset.data[mask_b,:][:N,:].T\n",
    "\n",
    "    y = np.empty((2, N))\n",
    "    y[0,:] = -1\n",
    "    y[1,:] = +1\n",
    "\n",
    "    # Flatten the class axis: columns [0, N) are digit a, [N, 2N) are digit b.\n",
    "    X.shape = M, 2*N\n",
    "    y.shape = 2*N, 1\n",
    "    return X, y\n",
    "\n",
    "X, y = mnist(4, 9, 1000)\n",
    "\n",
    "print('Dimensionality: N={} samples, M={} features'.format(X.shape[1], X.shape[0]))\n",
    "\n",
    "X -= 127.5\n",
    "print('X in [{}, {}]'.format(np.min(X), np.max(X)))\n",
    "\n",
    "def plot_digit(nn):\n",
    "    \"\"\"Show the columns of the global X listed in nn as square images, side by side.\"\"\"\n",
    "    side = int(np.sqrt(X.shape[0]))\n",
    "    fig, axes = plt.subplots(1, len(nn), figsize=(15,5))\n",
    "    for ax, col in zip(axes, map(int, nn)):\n",
    "        # Indices may be given as floats (e.g. 1e2), hence the int conversion.\n",
    "        ax.imshow(X[:,col].reshape((side,side)))\n",
    "        ax.set_title('Label: y = {:.0f}'.format(y[col,0]))\n",
    "\n",
    "plot_digit([0, 1, 1e2, 1e2+1, 1e3, 1e3+1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regularized least squares\n",
    "\n",
    "## Reference: sklearn ridge regression\n",
    "\n",
    "* With normalized (centered) data, the objective is the same with or without a bias term."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def test_sklearn(tauR):\n",
    "    \"\"\"Fit sklearn's Ridge (no intercept) on the global X, y and print diagnostics.\n",
    "\n",
    "    Printed: the ridge objective, the norm of its gradient at the solution,\n",
    "    and the absolute mean residual.\n",
    "    \"\"\"\n",
    "    ridge = linear_model.Ridge(alpha=tauR, fit_intercept=False)\n",
    "    ridge.fit(X.T, y)\n",
    "    w = ridge.coef_.T\n",
    "\n",
    "    fitted = X.T @ w\n",
    "    objective = np.linalg.norm(fitted + ridge.intercept_ - y)**2 + tauR * np.linalg.norm(w)**2\n",
    "    gradient = 2 * X @ (fitted - y) + 2 * tauR * w\n",
    "    print('L = {}'.format(objective))\n",
    "    print('|dLw| = {}'.format(np.linalg.norm(gradient)))\n",
    "\n",
    "    # Normalized data: intercept should be small.\n",
    "    print('bias: {}'.format(abs(np.mean(y - fitted))))\n",
    "\n",
    "test_sklearn(1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def test_optim(clf, X, y, ax=None):\n",
    "    \"\"\"Test optimization on full dataset.\n",
    "\n",
    "    Fits clf on (X, y), prints the processing time, the objective and the\n",
    "    norms of whichever gradients the classifier exposes (dLc, dLw, dLz), and\n",
    "    plots the recorded loss curves on a log scale.\n",
    "\n",
    "    clf.fit must return a tuple; it is unpacked as the leading arguments of\n",
    "    clf.L and of the gradient methods, followed by y.\n",
    "    ax is the axis to draw on; a new figure is created when it is None.\n",
    "    \"\"\"\n",
    "    tstart = time.process_time()\n",
    "    ret = clf.fit(X, y)\n",
    "    print('Processing time: {}'.format(time.process_time()-tstart))\n",
    "    print('L = {}'.format(clf.L(*ret, y)))\n",
    "    if hasattr(clf, 'dLc'):\n",
    "        print('|dLc| = {}'.format(np.linalg.norm(clf.dLc(*ret, y))))\n",
    "    if hasattr(clf, 'dLw'):\n",
    "        print('|dLw| = {}'.format(np.linalg.norm(clf.dLw(*ret, y))))\n",
    "\n",
    "    has_loss = hasattr(clf, 'loss')\n",
    "    has_split = hasattr(clf, 'Lsplit')\n",
    "    # Create the axis whenever something is going to be plotted, so that the\n",
    "    # split-loss curve below never hits ax=None.\n",
    "    if ax is None and (has_loss or has_split):\n",
    "        ax = plt.figure().add_subplot(111)\n",
    "    if has_loss:\n",
    "        ax.semilogy(clf.loss)\n",
    "        ax.set_title('Convergence')\n",
    "        ax.set_xlabel('Iteration number')\n",
    "        ax.set_ylabel('Loss')\n",
    "    if has_split:\n",
    "        print('Lsplit = {}'.format(clf.Lsplit(*ret, y)))\n",
    "        print('|dLz| = {}'.format(np.linalg.norm(clf.dLz(*ret, y))))\n",
    "        ax.semilogy(clf.loss_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class rls:\n",
    "    \"\"\"Regularized least-squares (ridge) classifier solved in closed form.\n",
    "\n",
    "    The data matrix X given to the methods is used as features x samples:\n",
    "    predictions are X.T @ w.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(s, tauR, algo='solve'):\n",
    "        \"\"\"Store the regularization weight tauR and select the solver.\n",
    "\n",
    "        algo is 'solve' (linear solve) or 'inv' (explicit matrix inverse);\n",
    "        any other value raises ValueError.\n",
    "        \"\"\"\n",
    "        s.tauR = tauR\n",
    "        # Compare strings by value: `is` tests object identity, which only\n",
    "        # works when both strings happen to be interned.\n",
    "        if algo == 'solve':\n",
    "            s.fit = s.solve\n",
    "        elif algo == 'inv':\n",
    "            s.fit = s.inv\n",
    "        else:\n",
    "            raise ValueError('unknown algo: {!r}'.format(algo))\n",
    "\n",
    "    def L(s, X, y):\n",
    "        \"\"\"Objective: squared residual norm plus tauR times the squared norm of w.\"\"\"\n",
    "        return np.linalg.norm(X.T @ s.w - y)**2 + s.tauR * np.linalg.norm(s.w)**2\n",
    "\n",
    "    def dLw(s, X, y):\n",
    "        \"\"\"Gradient of the objective L with respect to w.\"\"\"\n",
    "        return 2 * X @ (X.T @ s.w - y) + 2 * s.tauR * s.w\n",
    "\n",
    "    def inv(s, X, y):\n",
    "        \"\"\"Set w from the normal equations using an explicit inverse; returns (X,).\"\"\"\n",
    "        s.w = np.linalg.inv(X @ X.T + s.tauR * np.identity(X.shape[0])) @ X @ y\n",
    "        return (X,)\n",
    "\n",
    "    def solve(s, X, y):\n",
    "        \"\"\"Set w from the normal equations using a linear solve; returns (X,).\"\"\"\n",
    "        s.w = np.linalg.solve(X @ X.T + s.tauR * np.identity(X.shape[0]), X @ y)\n",
    "        return (X,)\n",
    "\n",
    "    def predict(s, X):\n",
    "        \"\"\"Return the real-valued predictions X.T @ w.\"\"\"\n",
    "        return X.T @ s.w\n",
    "\n",
    "test_optim(rls(1e-3, 'solve'), X, y)\n",
    "test_optim(rls(1e-3, 'inv'), X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "t_start = time.process_time()\n",
    "z = graph.grid(int(np.sqrt(X.shape[0])))\n",
    "dist, idx = graph.distance_sklearn_metrics(z, k=4)\n",
    "A = graph.adjacency(dist, idx)\n",
    "L = graph.laplacian(A, True)\n",
    "lmax = graph.lmax(L)\n",
    "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lanczos basis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def lanczos(L, X, K):\n",
    "    \"\"\"Run the Lanczos iteration with operator L on all columns of X at once.\n",
    "\n",
    "    Returns V of shape (K, M, N) (V[0] is X with unit-norm columns), and the\n",
    "    coefficient arrays a and b of shape (K, N); b[0] is left at zero.\n",
    "    \"\"\"\n",
    "    M, N = X.shape\n",
    "    alpha = np.empty((K, N))\n",
    "    beta = np.zeros((K, N))\n",
    "    V = np.empty((K, M, N))\n",
    "    V[0] = X / np.linalg.norm(X, axis=0)\n",
    "    for k in range(K-1):\n",
    "        W = L.dot(V[k])\n",
    "        alpha[k] = np.sum(W * V[k], axis=0)\n",
    "        W = W - alpha[k] * V[k]\n",
    "        if k > 0:\n",
    "            # Remove the component along the previous vector.\n",
    "            W = W - beta[k] * V[k-1]\n",
    "        else:\n",
    "            W = W - 0\n",
    "        beta[k+1] = np.linalg.norm(W, axis=0)\n",
    "        V[k+1] = W / beta[k+1]\n",
    "    last = V[K-1]\n",
    "    alpha[K-1] = np.sum(L.dot(last) * last, axis=0)\n",
    "    return V, alpha, beta\n",
    "\n",
    "def lanczos_H_diag(a, b):\n",
    "    \"\"\"Diagonalize, for every column n, the K x K tridiagonal matrix built from a and b.\n",
    "\n",
    "    a[:, n] is placed on the diagonal and b[1:, n] next to it. Returns the\n",
    "    eigenvectors computed by np.linalg.eigh as an array of shape (K, K, N).\n",
    "    \"\"\"\n",
    "    K, N = a.shape\n",
    "    flat = np.zeros((K*K, N))\n",
    "    # In the flattened K x K layout a stride of K+1 walks along a diagonal.\n",
    "    flat[:K**2:K+1, :] = a\n",
    "    flat[1:(K-1)*K:K+1, :] = b[1:,:]\n",
    "    H = flat.reshape((K, K, N))\n",
    "    _, eigvecs = np.linalg.eigh(H.T, UPLO='L')\n",
    "    return np.swapaxes(eigvecs, 1, 2).T\n",
    "\n",
    "def lanczos_basis_eval(L, X, K):\n",
    "    \"\"\"Express the columns of X in the order-K Lanczos basis of L.\n",
    "\n",
    "    Returns Xt of shape (K, M, N) and the first row Q[0] (shape (K, N)) of the\n",
    "    eigenvector arrays returned by lanczos_H_diag.\n",
    "    \"\"\"\n",
    "    V, a, b = lanczos(L, X, K)\n",
    "    Q = lanczos_H_diag(a, b)\n",
    "    M, N = X.shape\n",
    "    Xt = np.empty((K, M, N))\n",
    "    for col in range(N):\n",
    "        Xt[..., col] = Q[..., col].T @ V[..., col]\n",
    "    first_row = Q[0]\n",
    "    Xt *= first_row[:, np.newaxis, :]\n",
    "    # Undo the column normalization applied at the start of lanczos().\n",
    "    Xt *= np.linalg.norm(X, axis=0)\n",
    "    return Xt, first_row"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tests\n",
    "\n",
    "* Memory arrangement for fastest computations: largest dimensions on the outside, i.e. fastest varying indices.\n",
    "* The einsum seems to be efficient for three operands."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def test():\n",
    "    \"\"\"Test the speed of filtering and weighting.\n",
    "\n",
    "    Reads Xt, K, F, M and N from the enclosing scope. Each of the four\n",
    "    implementations is timed over 1000 calls and its result is compared\n",
    "    against implementation 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def mult(impl=3):\n",
    "        # Integers are compared by value; `is` tests object identity.\n",
    "        if impl == 0:\n",
    "            Xb = Xt.view()\n",
    "            Xb.shape = (K, M*N)\n",
    "            XCb = Xb.T @ C  # in MN x F\n",
    "            XCb = XCb.T.reshape((F*M, N))\n",
    "            return (XCb.T @ w).squeeze()\n",
    "        elif impl == 1:\n",
    "            tmp = np.tensordot(Xt, C, (0,0))\n",
    "            return np.tensordot(tmp, W, ((0,2),(1,0)))\n",
    "        elif impl == 2:\n",
    "            tmp = np.tensordot(Xt, C, (0,0))\n",
    "            return np.einsum('ijk,ki->j', tmp, W)\n",
    "        elif impl == 3:\n",
    "            return np.einsum('kmn,fm,kf->n', Xt, W, C)\n",
    "        raise ValueError('unknown implementation: {!r}'.format(impl))\n",
    "\n",
    "    C = np.random.normal(0,1,(K,F))\n",
    "    W = np.random.normal(0,1,(F,M))\n",
    "    w = W.reshape((F*M, 1))\n",
    "    a = mult(impl=0)\n",
    "    for impl in range(4):\n",
    "        tstart = time.process_time()\n",
    "        for k in range(1000):\n",
    "            b = mult(impl)\n",
    "        print('Execution time (impl={}): {}'.format(impl, time.process_time() - tstart))\n",
    "        np.testing.assert_allclose(a, b)\n",
    "#test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GFL classification without weights\n",
    "\n",
    "* The matrix `Xw @ Xw.T` of the direct solver is singular, thus not invertible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class gflc_noweights:\n",
    "    \"\"\"Graph filter learning classifier without weights.\n",
    "\n",
    "    Only the K x F coefficient matrix C is learned. The filtered signals are\n",
    "    summed over the first two axes of the (F, M, N) response to give one\n",
    "    prediction per sample. The graph Laplacian is read from the global L.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(s, F, K, niter, algo='direct'):\n",
    "        \"\"\"Model hyper-parameters\"\"\"\n",
    "        s.F = F\n",
    "        s.K = K\n",
    "        s.niter = niter\n",
    "        # Compare strings by value: `is` tests object identity.\n",
    "        if algo == 'direct':\n",
    "            s.fit = s.direct\n",
    "        elif algo == 'sgd':\n",
    "            s.fit = s.sgd\n",
    "        else:\n",
    "            raise ValueError('unknown algo: {!r}'.format(algo))\n",
    "\n",
    "    def L(s, Xt, y):\n",
    "        \"\"\"Squared error between the summed filter responses and the labels y.\"\"\"\n",
    "        tmp = np.tensordot(s.C, Xt, (0,0)).sum((0,1)) - y.squeeze()\n",
    "        return np.linalg.norm(tmp)**2\n",
    "\n",
    "    def dLc(s, Xt, y):\n",
    "        \"\"\"Gradient of L with respect to C, shape (K, F); all F columns are equal.\"\"\"\n",
    "        tmp = np.tensordot(s.C, Xt, (0,0)).sum(axis=(0,1)) - y.squeeze()\n",
    "        # The factor 2 comes from differentiating the squared norm in L.\n",
    "        return 2 * np.dot(Xt, tmp).sum(1)[:,np.newaxis].repeat(s.F,1)\n",
    "\n",
    "    def sgd(s, X, y):\n",
    "        \"\"\"Run niter fixed-step gradient steps on C; returns (Xt,).\"\"\"\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        s.C = np.random.normal(0, 1, (s.K, s.F))\n",
    "        s.loss = [s.L(Xt, y)]\n",
    "        for t in range(s.niter):\n",
    "            # 5e-14 on the full gradient is a step of 1e-13 on half of it.\n",
    "            s.C -= 5e-14 * s.dLc(Xt, y)\n",
    "            s.loss.append(s.L(Xt, y))\n",
    "        return (Xt,)\n",
    "\n",
    "    def direct(s, X, y):\n",
    "        \"\"\"Solve for C with the normal equations (all-ones weights); returns (Xt,).\"\"\"\n",
    "        M, N = X.shape\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        s.C = np.random.normal(0, 1, (s.K, s.F))\n",
    "        W = np.ones((s.F, M))\n",
    "        # c aliases s.C, so writing into c updates the model.\n",
    "        c = s.C.reshape((s.K*s.F, 1))\n",
    "        s.loss = [s.L(Xt, y)]\n",
    "        Xw = np.einsum('kmn,fm->kfn', Xt, W)\n",
    "        Xw.shape = (s.K*s.F, N)\n",
    "        c[:] = np.linalg.solve(Xw @ Xw.T, Xw @ y)\n",
    "        s.loss.append(s.L(Xt, y))\n",
    "        return (Xt,)\n",
    "\n",
    "#test_optim(gflc_noweights(1, 4, 100, 'sgd'), X, y)\n",
    "#test_optim(gflc_noweights(1, 4, 0, 'direct'), X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GFL classification with weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "class gflc_weights():\n",
    "    \"\"\"Graph filter learning classifier with weights.\n",
    "\n",
    "    Learns the K x F filter coefficients C and the F x M weights W; the\n",
    "    prediction for sample n is einsum('kmn,kf,fm->n', Xt, C, W). The graph\n",
    "    Laplacian is read from the global L.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(s, F, K, tauR, niter, algo='direct'):\n",
    "        \"\"\"Model hyper-parameters\"\"\"\n",
    "        s.F = F\n",
    "        s.K = K\n",
    "        s.tauR = tauR\n",
    "        s.niter = niter\n",
    "        # Compare strings by value: `is` tests object identity.\n",
    "        if algo == 'direct':\n",
    "            s.fit = s.direct\n",
    "        elif algo == 'sgd':\n",
    "            s.fit = s.sgd\n",
    "        else:\n",
    "            raise ValueError('unknown algo: {!r}'.format(algo))\n",
    "\n",
    "    def L(s, Xt, y):\n",
    "        \"\"\"Objective: squared prediction error plus tauR times the squared norm of W.\"\"\"\n",
    "        tmp = np.einsum('kmn,kf,fm->n', Xt, s.C, s.W) - y.squeeze()\n",
    "        return np.linalg.norm(tmp)**2 + s.tauR * np.linalg.norm(s.W)**2\n",
    "\n",
    "    def dLw(s, Xt, y):\n",
    "        \"\"\"Gradient of L with respect to W, shape (F, M).\"\"\"\n",
    "        tmp = np.einsum('kmn,kf,fm->n', Xt, s.C, s.W) - y.squeeze()\n",
    "        return 2 * np.einsum('kmn,kf,n->fm', Xt, s.C, tmp) + 2 * s.tauR * s.W\n",
    "\n",
    "    def dLc(s, Xt, y):\n",
    "        \"\"\"Gradient of L with respect to C, shape (K, F).\"\"\"\n",
    "        tmp = np.einsum('kmn,kf,fm->n', Xt, s.C, s.W) - y.squeeze()\n",
    "        return 2 * np.einsum('kmn,n,fm->kf', Xt, tmp, s.W)\n",
    "\n",
    "    def sgd(s, X, y):\n",
    "        \"\"\"Run niter fixed-step gradient steps on C and W; returns (Xt,).\"\"\"\n",
    "        M, N = X.shape\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        s.C = np.random.normal(0, 1, (s.K, s.F))\n",
    "        s.W = np.random.normal(0, 1, (s.F, M))\n",
    "\n",
    "        s.loss = [s.L(Xt, y)]\n",
    "\n",
    "        for t in range(s.niter):\n",
    "            s.C -= 1e-12 * s.dLc(Xt, y)\n",
    "            s.W -= 1e-12 * s.dLw(Xt, y)\n",
    "            s.loss.append(s.L(Xt, y))\n",
    "\n",
    "        return (Xt,)\n",
    "\n",
    "    def direct(s, X, y):\n",
    "        \"\"\"Alternate closed-form updates of C and W for niter rounds; returns (Xt,).\"\"\"\n",
    "        M, N = X.shape\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        s.C = np.random.normal(0, 1, (s.K, s.F))\n",
    "        s.W = np.random.normal(0, 1, (s.F, M))\n",
    "        # Column-vector views: writing into c / w updates s.C / s.W in place.\n",
    "        c = s.C.view()\n",
    "        c.shape = (s.K*s.F, 1)\n",
    "        w = s.W.view()\n",
    "        w.shape = (s.F*M, 1)\n",
    "\n",
    "        s.loss = [s.L(Xt, y)]\n",
    "\n",
    "        for t in range(s.niter):\n",
    "            # With W fixed, the model is linear in C: least squares for c.\n",
    "            Xw = np.einsum('kmn,fm->kfn', Xt, s.W)\n",
    "            Xw.shape = (s.K*s.F, N)\n",
    "            c[:] = np.linalg.solve(Xw @ Xw.T, Xw @ y)\n",
    "\n",
    "            # With C fixed, the model is linear in W: ridge regression for w.\n",
    "            Z = np.einsum('kmn,kf->fmn', Xt, s.C)\n",
    "            Z.shape = (s.F*M, N)\n",
    "            w[:] = np.linalg.solve(Z @ Z.T + s.tauR * np.identity(s.F*M), Z @ y)\n",
    "\n",
    "            s.loss.append(s.L(Xt, y))\n",
    "\n",
    "        return (Xt,)\n",
    "\n",
    "    def predict(s, X):\n",
    "        \"\"\"Return one real-valued prediction per column of X.\"\"\"\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        return np.einsum('kmn,kf,fm->n', Xt, s.C, s.W)\n",
    "\n",
    "#test_optim(gflc_weights(3, 4, 1e-3, 50, 'sgd'), X, y)\n",
    "clf_weights = gflc_weights(F=3, K=50, tauR=1e4, niter=5, algo='direct')\n",
    "test_optim(clf_weights, X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GFL classification with splitting\n",
    "\n",
    "Solvers\n",
    "* Closed-form solution.\n",
    "* Stochastic gradient descent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class gflc_split():\n",
    "    \"\"\"Graph filter learning classifier with a splitting variable Z.\n",
    "\n",
    "    Z stands in for the filtered signals XCb; the term tauF * ||XCb - Z||^2 of\n",
    "    Lsplit ties the two together. C holds the K x F filter coefficients and w\n",
    "    the (F*M, 1) classifier weights. The graph Laplacian is read from the\n",
    "    global L.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(s, F, K, tauR, tauF, niter, algo='direct'):\n",
    "        \"\"\"Model hyper-parameters\"\"\"\n",
    "        s.F = F\n",
    "        s.K = K\n",
    "        s.tauR = tauR\n",
    "        s.tauF = tauF\n",
    "        s.niter = niter\n",
    "        # Compare strings by value: `is` tests object identity.\n",
    "        if algo == 'direct':\n",
    "            s.fit = s.direct\n",
    "        elif algo == 'sgd':\n",
    "            s.fit = s.sgd\n",
    "        else:\n",
    "            raise ValueError('unknown algo: {!r}'.format(algo))\n",
    "\n",
    "    def L(s, Xt, XCb, Z, y):\n",
    "        \"\"\"Objective evaluated on the filtered signals XCb (Z is not used).\"\"\"\n",
    "        return np.linalg.norm(XCb.T @ s.w - y)**2 + s.tauR * np.linalg.norm(s.w)**2\n",
    "\n",
    "    def Lsplit(s, Xt, XCb, Z, y):\n",
    "        \"\"\"Split objective: data term on Z, coupling term XCb - Z, regularization of w.\"\"\"\n",
    "        return np.linalg.norm(Z.T @ s.w - y)**2 + s.tauF * np.linalg.norm(XCb - Z)**2 + s.tauR * np.linalg.norm(s.w)**2\n",
    "\n",
    "    def dLw(s, Xt, XCb, Z, y):\n",
    "        \"\"\"Gradient of Lsplit with respect to w.\"\"\"\n",
    "        return 2 * Z @ (Z.T @ s.w - y) + 2 * s.tauR * s.w\n",
    "\n",
    "    def dLc(s, Xt, XCb, Z, y):\n",
    "        \"\"\"Gradient of Lsplit with respect to C.\"\"\"\n",
    "        Xb = Xt.reshape((s.K, -1)).T\n",
    "        Zb = Z.reshape((s.F, -1)).T\n",
    "        return 2 * s.tauF * Xb.T @ (Xb @ s.C - Zb)\n",
    "\n",
    "    def dLz(s, Xt, XCb, Z, y):\n",
    "        \"\"\"Gradient of Lsplit with respect to Z.\"\"\"\n",
    "        return 2 * s.w @ (s.w.T @ Z - y.T) + 2 * s.tauF * (Z - XCb)\n",
    "\n",
    "    def lanczos_filter(s, Xt):\n",
    "        \"\"\"Apply the filters C to Xt (K, M, N); returns a new (F*M, N) array.\"\"\"\n",
    "        M, N = Xt.shape[1:]\n",
    "        Xb = Xt.reshape((s.K, M*N)).T\n",
    "        XCb = Xb @ s.C  # in MN x F\n",
    "        XCb = XCb.T.reshape((s.F*M, N))  # Needs to copy data.\n",
    "        return XCb\n",
    "\n",
    "    def sgd(s, X, y):\n",
    "        \"\"\"Run niter fixed-step gradient steps on C, Z and w; returns (Xt, XCb, Z).\"\"\"\n",
    "        M, N = X.shape\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        s.C = np.zeros((s.K, s.F))\n",
    "        s.w = np.zeros((s.F*M, 1))\n",
    "        Z = np.random.normal(0, 1, (s.F*M, N))\n",
    "\n",
    "        # Filter with the initial C so that the first loss entries are computed\n",
    "        # from defined values rather than from uninitialized memory.\n",
    "        XCb = s.lanczos_filter(Xt)\n",
    "\n",
    "        s.loss = [s.L(Xt, XCb, Z, y)]\n",
    "        s.loss_split = [s.Lsplit(Xt, XCb, Z, y)]\n",
    "\n",
    "        for t in range(s.niter):\n",
    "            s.C -= 1e-7 * s.dLc(Xt, XCb, Z, y)\n",
    "            XCb[:] = s.lanczos_filter(Xt)\n",
    "            Z -= 1e-4 * s.dLz(Xt, XCb, Z, y)\n",
    "            s.w -= 1e-4 * s.dLw(Xt, XCb, Z, y)\n",
    "            s.loss.append(s.L(Xt, XCb, Z, y))\n",
    "            s.loss_split.append(s.Lsplit(Xt, XCb, Z, y))\n",
    "\n",
    "        return Xt, XCb, Z\n",
    "\n",
    "    def direct(s, X, y):\n",
    "        \"\"\"Alternate closed-form updates of C, Z and w for niter rounds; returns (Xt, XCb, Z).\"\"\"\n",
    "        M, N = X.shape\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        s.C = np.zeros((s.K, s.F))\n",
    "        s.w = np.zeros((s.F*M, 1))\n",
    "        Z = np.random.normal(0, 1, (s.F*M, N))\n",
    "\n",
    "        # Filter with the initial C so that the first loss entries are computed\n",
    "        # from defined values rather than from uninitialized memory.\n",
    "        XCb = s.lanczos_filter(Xt)\n",
    "        Xb = Xt.reshape((s.K, M*N)).T\n",
    "        # Z is only ever updated in place below, so Zb keeps tracking it.\n",
    "        Zb = Z.reshape((s.F, M*N)).T\n",
    "\n",
    "        s.loss = [s.L(Xt, XCb, Z, y)]\n",
    "        s.loss_split = [s.Lsplit(Xt, XCb, Z, y)]\n",
    "\n",
    "        for t in range(s.niter):\n",
    "\n",
    "            s.C[:] = Xb.T @ Zb / np.sum((np.linalg.norm(X, axis=0) * q)**2, axis=1)[:,np.newaxis]\n",
    "            XCb[:] = s.lanczos_filter(Xt)\n",
    "\n",
    "            Z[:] = np.linalg.solve(s.tauF * np.identity(s.F*M) + s.w @ s.w.T, s.tauF * XCb + s.w @ y.T)\n",
    "\n",
    "            s.w[:] = np.linalg.solve(Z @ Z.T + s.tauR * np.identity(s.F*M), Z @ y)\n",
    "\n",
    "            s.loss.append(s.L(Xt, XCb, Z, y))\n",
    "            s.loss_split.append(s.Lsplit(Xt, XCb, Z, y))\n",
    "\n",
    "        return Xt, XCb, Z\n",
    "\n",
    "    def predict(s, X):\n",
    "        \"\"\"Return one real-valued prediction per column of X, shape (N, 1).\"\"\"\n",
    "        Xt, q = lanczos_basis_eval(L, X, s.K)\n",
    "        XCb = s.lanczos_filter(Xt)\n",
    "        return XCb.T @ s.w\n",
    "\n",
    "#test_optim(gflc_split(3, 4, 1e-3, 1e-3, 50, 'sgd'), X, y)\n",
    "clf_split = gflc_split(3, 4, 1e4, 1e-3, 8, 'direct')\n",
    "test_optim(clf_split, X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Filters visualization\n",
    "\n",
    "Observations:\n",
    "* Filters learned with the splitting scheme have much smaller amplitudes.\n",
    "* Maybe the energy sometimes goes in W ?\n",
    "* Why are the filters so different ?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "lamb, U = graph.fourier(L)\n",
    "print('Spectrum in [{:1.2e}, {:1.2e}]'.format(lamb[0], lamb[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_filters(C, spectrum=False):\n",
    "    \"\"\"Plot the response of each filter in C to a Kronecker delta and, optionally, its spectrum.\n",
    "\n",
    "    C is the K x F coefficient matrix. The graph quantities L, U, lamb and\n",
    "    lmax are read from the notebook's global scope.\n",
    "    \"\"\"\n",
    "    order, nfilters = C.shape\n",
    "    nvertices = L.shape[1]\n",
    "    side = int(np.sqrt(nvertices))\n",
    "\n",
    "    # Kronecker\n",
    "    delta = np.zeros((nvertices,1))\n",
    "    delta[int(side/2*(side+1))] = 1\n",
    "    basis, _ = lanczos_basis_eval(L, delta, order)\n",
    "    response = np.einsum('kmn,kf->mnf', basis, C)\n",
    "    delta_hat = U.T @ delta\n",
    "    response_hat = np.tensordot(U.T, response, (1,0))\n",
    "\n",
    "    # Crop a window of half-width K around the middle of the image.\n",
    "    window = slice(int(side/2) - order, int(side/2) + order + 1)\n",
    "    vmin, vmax = response.min(), response.max()\n",
    "    fig, axes = plt.subplots(2, int(np.ceil(nfilters/2)), figsize=(15,5))\n",
    "    for f, ax in zip(range(nfilters), axes.flat):\n",
    "        img = response[:,0,f].reshape((side,side))[window,window]\n",
    "        im = ax.imshow(img, vmin=vmin, vmax=vmax, interpolation='none')\n",
    "        ax.set_title('Filter {}'.format(f))\n",
    "    fig.subplots_adjust(right=0.8)\n",
    "    cax = fig.add_axes([0.82, 0.16, 0.02, 0.7])\n",
    "    fig.colorbar(im, cax=cax)\n",
    "\n",
    "    if not spectrum:\n",
    "        return\n",
    "    ax = plt.figure(figsize=(15,5)).add_subplot(111)\n",
    "    for f in range(nfilters):\n",
    "        ax.plot(lamb, response_hat[...,f] / delta_hat, '.-', label='Filter {}'.format(f))\n",
    "    ax.legend(loc='best')\n",
    "    ax.set_title('Spectrum of learned filters')\n",
    "    ax.set_xlabel('Frequency')\n",
    "    ax.set_ylabel('Amplitude')\n",
    "    ax.set_xlim(0, lmax)\n",
    "\n",
    "plot_filters(clf_weights.C, True)\n",
    "plot_filters(clf_split.C, True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extracted features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_features(C, x):\n",
    "    \"\"\"Show the single-column sample x after filtering by each of the F filters in C.\n",
    "\n",
    "    The graph Laplacian is read from the global L.\n",
    "    \"\"\"\n",
    "    order, nfilters = C.shape\n",
    "    side = int(np.sqrt(x.shape[0]))\n",
    "    basis, _ = lanczos_basis_eval(L, x, order)\n",
    "    features = np.einsum('kmn,kf->mnf', basis, C)\n",
    "\n",
    "    fig, axes = plt.subplots(2, int(np.ceil(nfilters/2)), figsize=(15,5))\n",
    "    for f, ax in zip(range(nfilters), axes.flat):\n",
    "        # Each panel uses its own colour scale; the colorbar refers to the last one.\n",
    "        im = ax.imshow(features[:,0,f].reshape((side,side)), interpolation='none')\n",
    "        ax.set_title('Filter {}'.format(f))\n",
    "    fig.subplots_adjust(right=0.8)\n",
    "    cax = fig.add_axes([0.82, 0.16, 0.02, 0.7])\n",
    "    fig.colorbar(im, cax=cax)\n",
    "\n",
    "plot_features(clf_weights.C, X[:,[0]])\n",
    "plot_features(clf_weights.C, X[:,[1000]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Performance w.r.t. hyper-parameters\n",
    "\n",
    "* F plays a big role.\n",
    "    * Both for performance and training time.\n",
    "    * Larger values lead to over-fitting !\n",
    "* Order $K \\in [3,5]$ seems sufficient.\n",
    "* $\\tau_R$ does not have much influence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def scorer(clf, X, y):\n",
    "    \"\"\"Return the accuracy (fraction in [0, 1]) of clf on the samples X with labels y.\n",
    "\n",
    "    Predictions are classified by sign: negative values are read as -1,\n",
    "    everything else as +1. y holds labels in {-1, +1}.\n",
    "    \"\"\"\n",
    "    yest = clf.predict(X).squeeze()\n",
    "    y = y.squeeze()\n",
    "    # Threshold the raw prediction. Rounding first would map values in\n",
    "    # [-0.5, 0) to -0.0, which is not < 0 and would be counted as +1.\n",
    "    yy = np.ones(len(y))\n",
    "    yy[yest < 0] = -1\n",
    "    nerrs = np.count_nonzero(y - yy)\n",
    "    return 1 - nerrs / len(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def perf(clf, nfolds=3):\n",
    "    \"\"\"Fit clf on all folds but the first and plot its predictions and accuracy.\n",
    "\n",
    "    The global X, y are split at random into nfolds folds: the first one is\n",
    "    the testing set, the others form the training set.\n",
    "    \"\"\"\n",
    "    nsamples = X.shape[1]\n",
    "    per_fold = int(nsamples/nfolds)\n",
    "    inds = np.arange(nsamples)\n",
    "    np.random.shuffle(inds)\n",
    "    # Samples that do not fill a complete fold are dropped.\n",
    "    inds = inds[:nfolds*per_fold].reshape((nfolds, per_fold))\n",
    "    test = inds[0]\n",
    "    train = inds[1:].reshape(-1)\n",
    "\n",
    "    fig, axes = plt.subplots(1,3, figsize=(15,5))\n",
    "    test_optim(clf, X[:,train], y[train], axes[2])\n",
    "\n",
    "    panels = [(axes[0], train, 'Training set accuracy: {:.2f}'),\n",
    "              (axes[1], test, 'Testing set accuracy: {:.2f}')]\n",
    "    for ax, subset, title in panels:\n",
    "        ax.plot(subset, clf.predict(X[:,subset]), '.')\n",
    "        ax.plot(subset, y[subset].squeeze(), '.')\n",
    "        ax.set_ylim([-3,3])\n",
    "        ax.set_title(title.format(scorer(clf, X[:,subset], y[subset])))\n",
    "\n",
    "    if hasattr(clf, 'C'):\n",
    "        plot_filters(clf.C)\n",
    "\n",
    "perf(rls(tauR=1e6))\n",
    "for F in [1,3,5]:\n",
    "    perf(gflc_weights(F=F, K=50, tauR=1e4, niter=5, algo='direct'))\n",
    "\n",
    "#perf(rls(tauR=1e-3))\n",
    "#for K in [2,3,5,7]:\n",
    "#    perf(gflc_weights(F=3, K=K, tauR=1e-3, niter=5, algo='direct'))\n",
    "\n",
    "#for tauR in [1e-3, 1e-1, 1e1]:\n",
    "#    perf(rls(tauR=tauR))\n",
    "#    perf(gflc_weights(F=3, K=3, tauR=tauR, niter=5, algo='direct'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification\n",
    "\n",
    "* The greater $F$ is, the greater $K$ should be."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cross_validation(clf, nfolds, nvalidations):\n",
    "    \"\"\"Cross-validate clf on the global X, y.\n",
    "\n",
    "    Runs nvalidations independent random splits into nfolds folds and returns\n",
    "    the mean and the standard deviation of the accuracy, both in percent.\n",
    "    \"\"\"\n",
    "    nsamples = X.shape[1]\n",
    "    per_fold = int(nsamples/nfolds)\n",
    "    scores = np.empty((nvalidations, nfolds))\n",
    "    for nval in range(nvalidations):\n",
    "        inds = np.arange(nsamples)\n",
    "        np.random.shuffle(inds)\n",
    "        # Samples that do not fill a complete fold are dropped.\n",
    "        inds = inds[:nfolds*per_fold].reshape((nfolds, per_fold))\n",
    "        for n in range(nfolds):\n",
    "            test = inds[n]\n",
    "            train = np.delete(inds, n, axis=0).reshape(-1)\n",
    "            clf.fit(X[:,train], y[train])\n",
    "            scores[nval, n] = scorer(clf, X[:,test], y[test])\n",
    "    return scores.mean()*100, scores.std()*100\n",
    "    #print('Accuracy: {:.2f} +- {:.2f}'.format(scores.mean()*100, scores.std()*100))\n",
    "    #print(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def test_classification(clf, params, param, values, nfolds=10, nvalidations=1):\n",
    "    \"\"\"Cross-validate clf(**params) for every value of one hyper-parameter and plot the accuracies.\n",
    "\n",
    "    params is updated in place: params[param] ends up holding the last value tried.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(1,1, figsize=(15,5))\n",
    "    results = []\n",
    "    for pos, val in enumerate(values):\n",
    "        params[param] = val\n",
    "        mean, std = cross_validation(clf(**params), nfolds, nvalidations)\n",
    "        results.append((mean, std))\n",
    "        ax.annotate('{:.2f} +- {:.2f}'.format(mean,std), xy=(pos,mean), xytext=(10,10), textcoords='offset points')\n",
    "    means = [mean for mean, _ in results]\n",
    "    stds = [std for _, std in results]\n",
    "    ax.errorbar(np.arange(len(values)), means, stds, fmt='.', markersize=10)\n",
    "    ax.set_xlim(-.8, len(values)-.2)\n",
    "    ax.set_xticks(np.arange(len(values)))\n",
    "    ax.set_xticklabels(values)\n",
    "    ax.set_xlabel(param)\n",
    "    ax.set_ylim(50, 100)\n",
    "    ax.set_ylabel('Accuracy')\n",
    "    ax.set_title('Parameters: {}'.format(params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "test_classification(rls, {}, 'tauR', [1e8,1e7,1e6,1e5,1e4,1e3,1e-5,1e-8], 10, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params = {'F':1, 'K':2, 'tauR':1e3, 'niter':5, 'algo':'direct'}\n",
    "test_classification(gflc_weights, params, 'tauR', [1e8,1e6,1e5,1e4,1e3,1e2,1e-3,1e-8], 10, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params = {'F':2, 'K':10, 'tauR':1e4, 'niter':5, 'algo':'direct'}\n",
    "test_classification(gflc_weights, params, 'F', [1,2,3,5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params = {'F':2, 'K':4, 'tauR':1e4, 'niter':5, 'algo':'direct'}\n",
    "test_classification(gflc_weights, params, 'K', [2,3,4,5,8,10,20,30,50,70])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sampled MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Xfull = X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def sample(X, p, seed=None):\n",
    "    \"\"\"Keep a random fraction p of the rows of X.\n",
    "\n",
    "    Returns the grid coordinates of the kept rows and the matching rows of X.\n",
    "    Note that this re-seeds NumPy's global random generator with seed.\n",
    "    \"\"\"\n",
    "    npixels, _ = X.shape\n",
    "    coords = graph.grid(int(np.sqrt(npixels)))\n",
    "\n",
    "    # Select random pixels.\n",
    "    np.random.seed(seed)\n",
    "    order = np.arange(npixels)\n",
    "    np.random.shuffle(order)\n",
    "    keep = order[:int(p*npixels)]\n",
    "\n",
    "    return coords[keep,:], X[keep,:]\n",
    "\n",
    "X = Xfull\n",
    "z, X = sample(X, .5)\n",
    "dist, idx = graph.distance_sklearn_metrics(z, k=4)\n",
    "A = graph.adjacency(dist, idx)\n",
    "L = graph.laplacian(A)\n",
    "lmax = graph.lmax(L)\n",
    "lamb, U = graph.fourier(L)\n",
    "print('Spectrum in [{:1.2e}, {:1.2e}]'.format(lamb[0], lamb[-1]))\n",
    "\n",
    "print(L.shape)\n",
    "\n",
    "def plot(n):\n",
    "    \"\"\"Scatter-plot column n of the global X at the coordinates z, coloured by value.\"\"\"\n",
    "    M, N = X.shape\n",
    "    # Shift the values back by the 127.5 that was subtracted from X.\n",
    "    colours = X[:,n] + 127.5\n",
    "    plt.scatter(z[:,0], -z[:,1], s=20, c=colours)\n",
    "plot(10)\n",
    "\n",
    "def plot_digit(nn):\n",
    "    \"\"\"Show the columns of the global X listed in nn as square images, side by side.\"\"\"\n",
    "    side = int(np.sqrt(X.shape[0]))\n",
    "    fig, axes = plt.subplots(1, len(nn), figsize=(15,5))\n",
    "    for ax, col in zip(axes, map(int, nn)):\n",
    "        # Indices may be given as floats (e.g. 1e2), hence the int conversion.\n",
    "        ax.imshow(X[:,col].reshape((side,side)))\n",
    "        ax.set_title('Label: y = {:.0f}'.format(y[col,0]))\n",
    "\n",
    "#plot_digit([0, 1, 1e2, 1e2+1, 1e3, 1e3+1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#clf_weights = gflc_weights(F=3, K=4, tauR=1e-3, niter=5, algo='direct')\n",
    "#test_optim(clf_weights, X, y)\n",
    "#plot_filters(clf_weights.C, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#test_classification(rls, {}, 'tauR', [1e1,1e0])\n",
    "#params = {'F':2, 'K':5, 'tauR':1e-3, 'niter':5, 'algo':'direct'}\n",
    "#test_classification(gflc_weights, params, 'F', [1,2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "test_classification(rls, {}, 'tauR', [1e8,1e7,1e6,1e5,1e4,1e3,1e-5,1e-8], 10, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params = {'F':2, 'K':2, 'tauR':1e3, 'niter':5, 'algo':'direct'}\n",
    "test_classification(gflc_weights, params, 'tauR', [1e8,1e5,1e4,1e3,1e2,1e1,1e-3,1e-8], 10, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params = {'F':2, 'K':10, 'tauR':1e5, 'niter':5, 'algo':'direct'}\n",
    "test_classification(gflc_weights, params, 'F', [1,2,3,4,5,10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "params = {'F':2, 'K':4, 'tauR':1e5, 'niter':5, 'algo':'direct'}\n",
    "test_classification(gflc_weights, params, 'K', [2,3,4,5,6,7,8,10,20,30])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
